import math
import sys
import uuid
from dataclasses import dataclass
from enum import Enum
from multiprocessing import Queue
from queue import Full
from time import time
from typing import Dict, List, Optional, Tuple

from scipy.interpolate import interp1d

import conf.settings as settings
from clogger import logger
from lib.metric_exception import MetricCollectException
from lib.metric_exception import MetricProcessException
from lib.metric_exception import MetricSettingsException
from metric_reader import InstantQueryTask, MetricReader, RangeQueryTask
from sysom_utils import SysomFramework

# Prometheus label names used to scope queries, sourced from service settings.
CLUSTER_LABEL = settings.CLUSTER_LABEL
POD_LABEL = settings.POD_LABEL
NODE_LABEL = settings.NODE_LABEL
POD_METRIC_TAG = settings.POD_METRIC_TAG


class Level(Enum):
    """Granularity a metric is collected and scored at."""
    Cluster = 0
    Node = 1
    Pod = 2


class RangeAggregationType(Enum):
    """PromQL function applied over a range vector for counter metrics.

    The member name, lowercased, is passed verbatim as the query
    aggregation string (e.g. Increase -> "increase",
    MAX_OVER_TIME -> "max_over_time"); see _default_single_counter.
    """
    Increase = 0
    Rate = 1
    Irate = 2
    MAX_OVER_TIME = 3
    AVG_OVER_TIME = 4


class InsAggregationType(Enum):
    """How per-series instant values are combined into one final value."""
    Sum = 0
    Max = 1
    Avg = 2


@dataclass
class Collect:
    """How one metric is collected from Prometheus."""
    metric_name: str  # Prometheus metric name to query
    related_value: List[str]  # label values; [0] is used for gauge queries
    standard_type: int  # not used in this file — presumably selects the
                        # collection/processing flavor; confirm with callers
    node_tag_name: Optional[str] = None  # label name used to filter
                                         # node/cluster-level queries
    filename: Optional[str] = None  # not used in this file — verify usage


@dataclass
class Score:
    """Scoring configuration for one metric."""
    weight: float  # not used in this file — presumably the metric's weight
                   # when combining scores; confirm with callers
    score: Dict[str, int]  # maps score (string key) -> metric value
                           # threshold; feeds the interpolation curve


@dataclass
class Alarm:
    """Alarm/diagnose configuration for one metric."""
    threshold: int  # not used in this file — presumably the score below
                    # which an alarm fires; confirm with callers
    diagnose_type: str  # "link" or "custom" (see deliver_diagnose)
    diagnose_url: Optional[str] = None  # required when diagnose_type=="link"
    service_name: Optional[str] = None  # required when diagnose_type=="custom"


@dataclass
class MetricSettings:
    """Parsed per-metric configuration (built in Metric._initalize_settings)."""
    description: str  # human-readable metric name, used in messages
    collect: Collect  # how to query the metric
    score: Score  # how to score the collected value
    alarm: Optional[Alarm]  # alarm/diagnose config; None disables alarming


@dataclass
class DiagnoseInfo:
    """Message placed on the diagnose queue for the diagnose worker."""
    alarm_id: str  # id of the alarm this diagnosis belongs to
    service_name: str  # diagnosis service to invoke
    type: str  # metric type tag
    level: Level  # metric level (cluster/node/pod)
    metric_description: str
    instance: str  # node name (see deliver_diagnose)
    # The fields below are not set in this file — presumably filled in
    # later by the diagnose worker; confirm against its implementation.
    pod: Optional[str] = None
    container: Optional[str] = None
    time: Optional[str] = None
    diagnose_type: Optional[str] = None


class Metric():
    def __init__(self, metric_reader: MetricReader,
                 metric_settings: Dict, level: Level):
        """Create a metric bound to a reader, settings, and a level.

        Args:
            metric_reader: Prometheus reader used for all queries.
            metric_settings: raw settings dict; parsed and validated by
                _initalize_settings.
            level: level this metric is collected at.

        Raises:
            MetricSettingsException: if metric_settings is invalid.
        """
        self.metric_reader = metric_reader
        self.level = level
        # cluster/node/pod name keyed by Level; filled in by metric_score()
        self.name = {}
        self.last_end_time = 0
        self.score_interp = None
        self.settings = None

        # The original wrapped this in `try/except BaseException: raise`,
        # which is a no-op; let settings errors propagate directly.
        self._initalize_settings(metric_settings)

    ##########################################################################
    # Inner functions
    ##########################################################################

    def _initalize_settings(self, settings):
        """Parse and validate the raw settings dict into MetricSettings.

        Args:
            settings: dict with keys "Description", "Collect", "Score"
                and optional "Alarm".

        Raises:
            MetricSettingsException: if a required field is missing or
                inconsistent with this metric's level / alarm type.
        """
        try:
            self.settings = MetricSettings(
                description=settings["Description"],
                collect=Collect(**settings["Collect"]),
                score=Score(**settings["Score"]),
                alarm=Alarm(
                    **settings["Alarm"]) if settings.get("Alarm") else None
            )

            # Node/cluster-level queries filter on the node tag
            # (see _default_single_gauge), so it must be configured.
            if self.level in (Level.Node, Level.Cluster) \
                    and not self.settings.collect.node_tag_name:
                raise MetricSettingsException(
                    f"node_tag_name must set "
                    f"in {self.settings.description}!"
                )

            alarm = self.settings.alarm
            if alarm is not None:
                if alarm.diagnose_type == "link" and not alarm.diagnose_url:
                    raise MetricSettingsException(
                        f"diagnose_url must set "
                        f"in {self.settings.description}!"
                    )

                if alarm.diagnose_type == "custom" and not alarm.service_name:
                    raise MetricSettingsException(
                        f"service_name must set "
                        f"in {self.settings.description}!"
                    )

            self._initalize_score_settings(self.settings.score.score)
        except MetricSettingsException:
            # Bug fix: previously these detailed validation errors were
            # re-wrapped into a message-less MetricSettingsException();
            # re-raise them unchanged so the message survives.
            raise
        except Exception as exc:
            raise MetricSettingsException(str(exc)) from exc

    def _initalize_score_settings(self, score_setting):
        """Build the metric-value -> score interpolation curve.

        Args:
            score_setting: dict mapping score (string key) to the metric
                value that earns that score.
        """
        # Sort by metric value so the boundary padding below is correct
        # even if the configured dict is not in ascending order.
        pairs = sorted((value, int(score))
                       for score, value in score_setting.items())
        X = [value for value, _ in pairs]
        Y = [score for _, score in pairs]

        # in early version of scipy, X[0] can't be 0
        if X[0] == 0:
            X[0] = -sys.float_info.epsilon

        # Score decreases as the metric value grows; pad both ends with
        # extreme points so values below the first threshold interpolate
        # toward 100 and values above the last threshold toward 0.
        # Bug fix: the left pad must be strictly below X[0] — inserting
        # -epsilon when X[0] was already shifted to -epsilon produced
        # duplicate x values, which breaks interp1d.
        X.insert(0, min(X[0], -sys.float_info.epsilon)
                 - sys.float_info.epsilon)
        Y.insert(0, 100)
        X.append(sys.maxsize)
        Y.append(0)
        self.score_interp = interp1d(X, Y)

    def _get_custom_metric(self, metric_name: str, **kwargs):
        """Range-query *metric_name* from the last collection end until now.

        Every keyword argument is applied as an equality label filter.
        """
        query = RangeQueryTask(metric_name,
                               start_time=self.last_end_time,
                               end_time=time())
        for label, expected in kwargs.items():
            query.append_equal_filter(label, expected)
        return self.metric_reader.range_query([query])

    def _aggregation(self, data: List[float],
                     aggre: InsAggregationType) -> float:
        """Reduce a list of per-series values into one value.

        Args:
            data: non-empty list of values to combine.
            aggre: the aggregation to apply.

        Raises:
            ValueError: if *aggre* is not a supported aggregation type.
                (Previously an unknown type silently returned None.)
        """
        if aggre == InsAggregationType.Sum:
            return sum(data)
        if aggre == InsAggregationType.Max:
            return max(data)
        if aggre == InsAggregationType.Avg:
            return sum(data) / len(data)
        raise ValueError(f"Unsupported aggregation type: {aggre}")

    def _default_single_gauge(
        self,
        ins_agg_type: InsAggregationType = InsAggregationType.Max
    ) -> float:
        """Collect and process one gauge metric.

        For a gauge, each returned series is reduced with max() over the
        range vector, then the per-series maxima are combined with
        *ins_agg_type*:
            pod     = aggregate(containers)
            cluster = aggregate(nodes)

        Args:
            ins_agg_type: how to combine the per-series maxima.

        Raises:
            MetricCollectException: Prometheus returned no data.
            MetricProcessException: the returned data could not be parsed.

        Returns:
            float: the aggregated metric value.
        """
        node_tag = self.settings.collect.node_tag_name
        val = self.settings.collect.related_value[0]

        if self.level == Level.Node:
            query_args = {
                NODE_LABEL: self.name[Level.Node],
                node_tag: val
            }
        elif self.level == Level.Pod:
            query_args = {
                NODE_LABEL: self.name[Level.Node],
                POD_LABEL: self.name[Level.Pod],
                POD_METRIC_TAG: val
            }
        else:
            # cluster level
            query_args = {
                CLUSTER_LABEL: self.name[Level.Cluster],
                node_tag: val
            }

        res = self._get_custom_metric(
            self.settings.collect.metric_name, **query_args)
        if len(res.data) <= 0:
            raise MetricCollectException(
                f"Collect {self.settings.collect.metric_name}, Level: "
                f"{self.level} from Prometheus: no data!"
            )

        try:
            # For container metrics, each series is one container of the
            # pod; node metrics should yield a single series; cluster
            # metrics yield one series per node.
            max_values = [
                # maximum over all points of the range vector
                max(float(point[1]) for point in series.to_dict()["values"])
                for series in res.data
            ]
            final_value = self._aggregation(max_values, ins_agg_type)
        except Exception as exc:
            raise MetricProcessException() from exc

        return final_value

    def _default_single_counter(
        self,
        related_value: str,
        range_agg_type: RangeAggregationType,
        ins_agg_type: InsAggregationType = InsAggregationType.Sum
    ) -> float:
        """Collect and process one counter metric.

        The range vector is already folded by the PromQL function named by
        *range_agg_type* (increase/rate/irate/...), so only the per-series
        instant values need to be combined with *ins_agg_type*.

        Args:
            related_value: label value identifying the concrete counter.
            range_agg_type: PromQL aggregation applied over the interval.
            ins_agg_type: how to combine the per-series values.

        Raises:
            MetricCollectException: Prometheus returned no data.
            MetricProcessException: the returned data could not be parsed.

        Returns:
            float: the aggregated metric value.
        """
        metric_name = self.settings.collect.metric_name
        node_tag = self.settings.collect.node_tag_name

        # query at least one minute of data
        query_interval = max(int(time() - self.last_end_time), 60)
        query_interval_str = f"{query_interval}s"

        # the enum member name doubles as the PromQL function name
        aggr_str = range_agg_type.name.lower()

        # Build the common task once; only the label filters differ per level.
        task = InstantQueryTask(metric_name,
                                time=time(), aggregation=aggr_str,
                                interval=query_interval_str)
        if self.level == Level.Pod:
            task = task \
                .append_equal_filter(NODE_LABEL, self.name[Level.Node]) \
                .append_equal_filter(POD_LABEL, self.name[Level.Pod]) \
                .append_equal_filter(POD_METRIC_TAG, related_value)
        elif self.level == Level.Node:
            task = task \
                .append_equal_filter(NODE_LABEL, self.name[Level.Node]) \
                .append_equal_filter(node_tag, related_value)
        else:
            # cluster level
            task = task \
                .append_equal_filter(CLUSTER_LABEL, self.name[Level.Cluster]) \
                .append_equal_filter(node_tag, related_value)

        res = self.metric_reader.instant_query([task])
        if len(res.data) <= 0:
            raise MetricCollectException(
                f"Collect {metric_name}, Value: {related_value},"
                f"Level: {self.level} from Prometheus failed: no data!"
            )

        try:
            # Typically each series is one container of the same pod; the
            # range vector was already folded by PromQL, so only the
            # container -> pod aggregation remains. Note that instant_query
            # results use the key "value" (not "values").
            values = [float(series.to_dict()["value"][1])
                      for series in res.data]
            return self._aggregation(values, ins_agg_type)
        except Exception as exc:
            raise MetricProcessException() from exc

    def _collect_process_metric(self) -> float:
        """Collect metric and preprocess metrics from prometheus,
        return a value to calculate score.

        Subclasses must override this.

        Returns:
            float: metric value after collect and preprocess
        """
        # Bug fix: corrected "implememted" typo in the error message.
        raise NotImplementedError("_collect_process_metric not implemented!")

    def _calculate_score(self, metric_value: float) -> float:
        """Map a metric value to a score via the interpolation curve."""
        interpolated = self.score_interp(metric_value)
        # interp1d yields a 0-d numpy array; unwrap to a plain float
        # before rounding to two decimal places.
        return round(float(interpolated.tolist()), 2)

    ##########################################################################
    # Outer functions
    ##########################################################################

    def deliver_alarm(self, metric_value: float, type: str) -> str:
        """Fire an alarm for this metric through SysomFramework.

        Args:
            metric_value: the offending metric value (rounded to 2 digits
                in the alarm summary).
            type: metric type string, attached as the "metric_type" label.

        Returns:
            str: the generated alarm id (uuid4).
        """
        alarm_id = str(uuid.uuid4())
        rounded_value = round(metric_value, 2)

        payload = {
            "alert_id": alarm_id,
            "instance": self.name[self.level],
            "alert_item": self.settings.description,
            "alert_category": "MONITOR",
            "alert_source_type": "health check",
            "alert_time": int(round(time() * 1000)),
            "status": "FIRING",
            "labels": {
                "cluster": self.name[Level.Cluster],
                "node": self.name[Level.Node],
                "pod": self.name[Level.Pod],
                "metric_type": type,
            },
            "annotations": {
                "summary": f"{self.settings.description} has low score with"
                f" value {rounded_value}"
            }
        }
        SysomFramework.alarm(payload)

        return alarm_id

    def deliver_diagnose(self, alarm_id: str, level: Level,
                         type: str, queue: Queue):
        """Deliver diagnose info to diagnose worker.

        For "link" alarms, attach a URL option to the alarm; for "custom"
        alarms, enqueue a DiagnoseInfo for the diagnose worker. A full
        queue drops the diagnosis (best effort) with an error log.

        Args:
            alarm_id: alarm id
            level: level of this metric
            type: metric type
            queue: diagnose queue
        """

        diagnose_type = self.settings.alarm.diagnose_type

        if diagnose_type == "link":
            SysomFramework.alarm_action("ADD_OPT", {
                "alert_id": str(alarm_id),
                "opt": {
                    "key": self.settings.description,
                    "label": self.settings.description,
                    "type": "LINK",
                    "url": self.settings.alarm.diagnose_url
                }
            })
        elif diagnose_type == "custom":
            try:
                # NOTE(review): instance is always the node name, even for
                # pod-level metrics — confirm this is intended.
                queue.put(
                    DiagnoseInfo(
                        alarm_id=str(alarm_id),
                        service_name=self.settings.alarm.service_name,
                        type=type,
                        level=level,
                        metric_description=self.settings.description,
                        instance=self.name[Level.Node],
                    ),
                    block=False
                )
            except Full:
                logger.error(f"Diagnose queue is full!, "
                             f"drop alarm {alarm_id}'s diagnose!")
            except Exception as e:
                # Best effort: log and continue (removed redundant `pass`).
                logger.error(f"Deliver diagnose info of alarm {alarm_id} "
                             f"to diagnose worker failed: {e}")

    def construct_diagnose_req(
            self, diagnose_info: DiagnoseInfo) -> Dict[str, str]:
        """Construct diagnose request to query sysom diagnosis center.

        Subclasses must override this.

        Args:
            diagnose_info: diagnose info

        Returns:
            Dict: query request
        """
        # Bug fix: corrected "implememted" typo in the error message.
        raise NotImplementedError("construct_diagnose_req not implemented!")

    def process_diagnose_req(self, result):
        """Process diagnose result from sysom diagnosis center.

        Subclasses must override this.

        Args:
            result: diagnose result
        """
        # Bug fix: corrected "implememted" typo in the error message.
        raise NotImplementedError("process_diagnose_req not implemented!")

    def metric_score(self, pod: str, node: str, cluster: str,
                     last_end_time: float) -> Tuple[float, float]:
        """Calculate the final score of this metric.

        Args:
            pod: pod name
            node: node name
            cluster: cluster name
            last_end_time: end time of last calculate interval

        Raises:
            MetricProcessException: any collect/score failure is wrapped
                into this (the original docstring listed the underlying
                exception types, but the code re-wraps them all).

        Returns:
            (metric_value, score): metric value and score after calculation
        """
        self.name[Level.Pod] = pod
        self.name[Level.Node] = node
        self.name[Level.Cluster] = cluster
        self.last_end_time = last_end_time

        try:
            metric_value = self._collect_process_metric()
            # scores are reported as whole numbers
            score = math.floor(self._calculate_score(metric_value))
        except Exception as e:
            err = (f"Calculate metric: {self.settings.collect.metric_name} "
                   f"score failed: {str(e)}!")
            raise MetricProcessException(err) from e

        return metric_value, score
